{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "263ec265",
   "metadata": {},
   "source": [
    "[![Dataflowr](https://raw.githubusercontent.com/dataflowr/website/master/_assets/dataflowr_logo.png)](https://dataflowr.github.io/website/)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96c28fae",
   "metadata": {},
   "source": [
    "# [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) for MNIST\n",
    "(J. Ho, A. Jain, P. Abbeel 2020)\n",
    "\n",
    "![](https://raw.githubusercontent.com/dataflowr/website/master/modules/extras/diffusions/ddpm.png)\n",
    "\n",
    "\n",
    "Given a schedule $\\beta_1<\\beta_2<\\dots <\\beta_T$, the **forward diffusion process** is defined by:\n",
    "$q(x_t|x_{t-1}) = \\mathcal{N}(x_t; \\sqrt{1-\\beta_t}x_{t-1},\\beta_t I)$ and $q(x_{1:T}|x_0) = \\prod_{t=1}^T q(x_t|x_{t-1})$.\n",
    "\n",
    "With $\\alpha_t = 1-\\beta_t$ and $\\overline{\\alpha_t} = \\prod_{i=1}^t\\alpha_i$, we see that, with $\\epsilon\\sim\\mathcal{N}(0,I)$:\n",
    "\\begin{align*}\n",
    "x_t = \\sqrt{\\overline{\\alpha}_t}x_0 + \\sqrt{1-\\overline{\\alpha}_t}\\epsilon.\n",
    "\\end{align*}\n",
    "The law $q(x_{t-1}|x_t,\\epsilon)$ is explicit: $q(x_{t-1}|x_t,\\epsilon) = \\mathcal{N}(x_{t-1};\\mu(x_t,\\epsilon,t), \\gamma_t I)$ with,\n",
    "\\begin{align*}\n",
    "\\mu(x_t,\\epsilon, t) = \\frac{1}{\\sqrt{\\alpha_t}}\\left( x_t-\\frac{1-\\alpha_t}{\\sqrt{1-\\overline{\\alpha}_t}}\\epsilon\\right)\\text{ and, }\n",
    "\\gamma_t = \\frac{1-\\overline{\\alpha}_{t-1}}{1-\\overline{\\alpha}_{t}}\\beta_t\n",
    "\\end{align*}\n",
    "\n",
    "\n",
    "**Training**: to approximate **the reversed diffusion** $q(x_{t-1}|x_t)$ by a neural network given by $p_{\\theta}(x_{t-1}|x_t) = \\mathcal{N}(x_{t-1}; \\mu_{\\theta}(x_t,t), \\beta_t I)$ and $p(x_T) \\sim \\mathcal{N}(0,I)$, we maximize the usual Variational bound:\n",
    "\\begin{align*}\n",
    "\\mathbb{E}_{q(x_0)} \\ln p_{\\theta}(x_0) &\\geq L_T +\\sum_{t=2}^T L_{t-1}+L_0 \\text{ with, }L_{t-1} = \\mathbb{E}_q\\left[ \\frac{1}{2\\sigma_t^2}\\|\\mu_\\theta(x_t,t) -\\mu(x_t,\\epsilon,t)\\|^2\\right].\n",
    "\\end{align*}\n",
    "With the change of variable:\n",
    "\\begin{align*}\n",
    "\\mu_\\theta(x_t,t) = \\frac{1}{\\sqrt{\\alpha_t}}\\left( x_t-\\frac{1-\\alpha_t}{\\sqrt{1-\\overline{\\alpha}_t}}\\epsilon_\\theta(x_t,t)\\right),\n",
    "\\end{align*}\n",
    "ignoring the prefactor and sampling $\\tau$ instead of summing over all $t$, the loss is finally:\n",
    "\\begin{align*}\n",
    "\\ell(\\theta) = \\mathbb{E}_\\tau\\mathbb{E}_\\epsilon \\left[ \\|\\epsilon - \\epsilon_\\theta(\\sqrt{\\overline{\\alpha}_\\tau}x_0 + \\sqrt{1-\\overline{\\alpha}_\\tau}\\epsilon, \\tau)\\|^2\\right]\n",
    "\\end{align*}\n",
    "\n",
    "\n",
    "\n",
    "**Sampling**: to simulate the reversed diffusion with the learned $\\epsilon_\\theta(x_t,t)$ starting from $x_T\\sim \\mathcal{N}(0,I)$, iterate for $t=T,\\dots, 1$:\n",
    "\\begin{align*}\n",
    "x_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}}\\left( x_t-\\frac{1-\\alpha_t}{\\sqrt{1-\\overline{\\alpha}_t}}\\epsilon_\\theta(x_t,t)\\right)+\\sqrt{\\beta_t}\\epsilon,\\text{ with } \\epsilon\\sim\\mathcal{N}(0,I).\n",
    "\\end{align*}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "934767ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0282585f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_images(images, title=\"\"):\n",
    "    \"\"\"Shows the provided images as sub-pictures in a square\"\"\"\n",
    "    images = [im.permute(1,2,0).numpy() for im in images]\n",
    "\n",
    "    # Defining number of rows and columns\n",
    "    fig = plt.figure(figsize=(8, 8))\n",
    "    rows = int(len(images) ** (1 / 2))\n",
    "    cols = round(len(images) / rows)\n",
    "\n",
    "    # Populating figure with sub-plots\n",
    "    idx = 0\n",
    "    for r in range(rows):\n",
    "        for c in range(cols):\n",
    "            fig.add_subplot(rows, cols, idx + 1)\n",
    "\n",
    "            if idx < len(images):\n",
    "                plt.imshow(images[idx], cmap=\"gray\")\n",
    "                plt.axis('off')\n",
    "                idx += 1\n",
    "    fig.suptitle(title, fontsize=30)\n",
    "    \n",
    "    # Showing the figure\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bec9075",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "713de2f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sinusoidal_embedding(n, d):\n",
    "    # Returns the standard positional embedding\n",
    "    embedding = torch.tensor([[i / 10_000 ** (2 * j / d) for j in range(d)] for i in range(n)])\n",
    "    sin_mask = torch.arange(0, n, 2)\n",
    "\n",
    "    embedding[sin_mask] = torch.sin(embedding[sin_mask])\n",
    "    embedding[1 - sin_mask] = torch.cos(embedding[sin_mask])\n",
    "\n",
    "    return embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "968c102d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyConv(nn.Module):\n",
    "    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, padding=1, activation=None, normalize=True):\n",
    "        super(MyConv, self).__init__()\n",
    "        self.ln = nn.LayerNorm(shape)\n",
    "        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size, stride, padding)\n",
    "        self.activation = nn.SiLU() if activation is None else activation\n",
    "        self.normalize = normalize\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = self.ln(x) if self.normalize else x\n",
    "        out = self.conv1(out)\n",
    "        out = self.activation(out)\n",
    "        return out\n",
    "    \n",
    "def MyTinyBlock(size, in_c, out_c):\n",
    "    return nn.Sequential(MyConv((in_c, size, size), in_c, out_c), \n",
    "                         MyConv((out_c, size, size), out_c, out_c), \n",
    "                         MyConv((out_c, size, size), out_c, out_c))\n",
    "\n",
    "def MyTinyUp(size, in_c):\n",
    "    return nn.Sequential(MyConv((in_c, size, size), in_c, in_c//2), \n",
    "                         MyConv((in_c//2, size, size), in_c//2, in_c//4), \n",
    "                         MyConv((in_c//4, size, size), in_c//4, in_c//4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6a33a80",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyTinyUNet(nn.Module):\n",
    "  # Here is a network with 3 down and 3 up with the tiny block\n",
    "    def __init__(self, in_c=1, out_c=1, size=32, n_steps=1000, time_emb_dim=100):\n",
    "        super(MyTinyUNet, self).__init__()\n",
    "\n",
    "        # Sinusoidal embedding\n",
    "        self.time_embed = nn.Embedding(n_steps, time_emb_dim)\n",
    "        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)\n",
    "        self.time_embed.requires_grad_(False)\n",
    "\n",
    "        # First half\n",
    "        self.te1 = self._make_te(time_emb_dim, 1)\n",
    "        self.b1 = MyTinyBlock(size, in_c, 10)\n",
    "        self.down1 = nn.Conv2d(10, 10, 4, 2, 1)\n",
    "        self.te2 = self._make_te(time_emb_dim, 10)\n",
    "        self.b2 = MyTinyBlock(size//2, 10, 20)\n",
    "        self.down2 = nn.Conv2d(20, 20, 4, 2, 1)\n",
    "        self.te3 = self._make_te(time_emb_dim, 20)\n",
    "        self.b3 = MyTinyBlock(size//4, 20, 40)\n",
    "        self.down3 = nn.Conv2d(40, 40, 4, 2, 1)\n",
    "\n",
    "        # Bottleneck\n",
    "        self.te_mid = self._make_te(time_emb_dim, 40)\n",
    "        self.b_mid = nn.Sequential(\n",
    "            MyConv((40, size//8, size//8), 40, 20),\n",
    "            MyConv((20, size//8, size//8), 20, 20),\n",
    "            MyConv((20, size//8, size//8), 20, 40)\n",
    "        )\n",
    "\n",
    "        # Second half\n",
    "        self.up1 = nn.ConvTranspose2d(40, 40, 4, 2, 1)\n",
    "        self.te4 = self._make_te(time_emb_dim, 80)\n",
    "        self.b4 = MyTinyUp(size//4, 80)\n",
    "        self.up2 = nn.ConvTranspose2d(20, 20, 4, 2, 1)\n",
    "        self.te5 = self._make_te(time_emb_dim, 40)\n",
    "        self.b5 = MyTinyUp(size//2, 40)\n",
    "        self.up3 = nn.ConvTranspose2d(10, 10, 4, 2, 1)\n",
    "        self.te_out = self._make_te(time_emb_dim, 20)\n",
    "        self.b_out = MyTinyBlock(size, 20, 10)\n",
    "        self.conv_out = nn.Conv2d(10, out_c, 3, 1, 1)\n",
    "\n",
    "    def forward(self, x, t): # x is (bs, in_c, size, size) t is (bs)\n",
    "        t = self.time_embed(t)\n",
    "        n = len(x)\n",
    "        out1 = self.b1(x + self.te1(t).reshape(n, -1, 1, 1))  # (bs, 10, size/2, size/2)\n",
    "        out2 = self.b2(self.down1(out1) + self.te2(t).reshape(n, -1, 1, 1))  # (bs, 20, size/4, size/4)\n",
    "        out3 = self.b3(self.down2(out2) + self.te3(t).reshape(n, -1, 1, 1))  # (bs, 40, size/8, size/8)\n",
    "\n",
    "        out_mid = self.b_mid(self.down3(out3) + self.te_mid(t).reshape(n, -1, 1, 1))  # (bs, 40, size/8, size/8)\n",
    "\n",
    "        out4 = torch.cat((out3, self.up1(out_mid)), dim=1)  # (bs, 80, size/8, size/8)\n",
    "        out4 = self.b4(out4 + self.te4(t).reshape(n, -1, 1, 1))  # (bs, 20, size/8, size/8)\n",
    "        out5 = torch.cat((out2, self.up2(out4)), dim=1)  # (bs, 40, size/4, size/4)\n",
    "        out5 = self.b5(out5 + self.te5(t).reshape(n, -1, 1, 1))  # (bs, 10, size/2, size/2)\n",
    "        out = torch.cat((out1, self.up3(out5)), dim=1)  # (bs, 20, size, size)\n",
    "        out = self.b_out(out + self.te_out(t).reshape(n, -1, 1, 1))  # (bs, 10, size, size)\n",
    "        out = self.conv_out(out) # (bs, out_c, size, size)\n",
    "        return out\n",
    "\n",
    "    def _make_te(self, dim_in, dim_out):\n",
    "        return nn.Sequential(nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71d57f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 3\n",
    "x = torch.randn(bs,1,32,32)\n",
    "n_steps=1000\n",
    "timesteps = torch.randint(0, n_steps, (bs,)).long()\n",
    "unet = MyTinyUNet(in_c =1, out_c =1, size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1762df",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = unet(x,timesteps)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4c37c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DDPM(nn.Module):\n",
    "    def __init__(self, network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device) -> None:\n",
    "        super(DDPM, self).__init__()\n",
    "        self.num_timesteps = num_timesteps\n",
    "        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32).to(device)\n",
    "        self.alphas = 1.0 - self.betas\n",
    "        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)\n",
    "        self.network = network\n",
    "        self.device = device\n",
    "        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5 # used in add_noise\n",
    "        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5 # used in add_noise and step\n",
    "\n",
    "    def add_noise(self, x_start, x_noise, timesteps):\n",
    "        # The forward process\n",
    "        # x_start and x_noise (bs, n_c, w, d)\n",
    "        # timesteps (bs)\n",
    "        s1 = self.sqrt_alphas_cumprod[timesteps] # bs\n",
    "        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps] # bs\n",
    "        s1 = s1.reshape(-1,1,1,1) # (bs, 1, 1, 1) for broadcasting\n",
    "        s2 = s2.reshape(-1,1,1,1) # (bs, 1, 1, 1)\n",
    "        return s1 * x_start + s2 * x_noise\n",
    "\n",
    "    def reverse(self, x, t):\n",
    "        # The network return the estimation of the noise we added\n",
    "        return self.network(x, t)\n",
    "    \n",
    "    def step(self, model_output, timestep, sample):\n",
    "        # one step of sampling\n",
    "        # timestep (1)\n",
    "        t = timestep\n",
    "        coef_epsilon = (1-self.alphas)/self.sqrt_one_minus_alphas_cumprod\n",
    "        coef_eps_t = coef_epsilon[t].reshape(-1,1,1,1)\n",
    "        coef_first = 1/self.alphas ** 0.5\n",
    "        coef_first_t = coef_first[t].reshape(-1,1,1,1)\n",
    "        pred_prev_sample = coef_first_t*(sample-coef_eps_t*model_output)\n",
    "\n",
    "        variance = 0\n",
    "        if t > 0:\n",
    "            noise = torch.randn_like(model_output).to(self.device)\n",
    "            variance = ((self.betas[t] ** 0.5) * noise)\n",
    "            \n",
    "        pred_prev_sample = pred_prev_sample + variance\n",
    "\n",
    "        return pred_prev_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a730dcf",
   "metadata": {},
   "outputs": [],
   "source": [
    "num_timesteps = 1000\n",
    "betas = torch.linspace(0.0001, 0.02, num_timesteps, dtype=torch.float32).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "818ce5f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "betas[timesteps]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9905682a",
   "metadata": {},
   "outputs": [],
   "source": [
    "betas[10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47645019",
   "metadata": {},
   "outputs": [],
   "source": [
    "betas[timesteps].reshape(-1,1,1,1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29c69328",
   "metadata": {},
   "outputs": [],
   "source": [
    "network = MyTinyUNet(in_c =1, out_c =1, size=32)\n",
    "model = DDPM(network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8b71af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 5\n",
    "x = torch.randn(bs,1,32,32).to(device)\n",
    "timesteps = 10*torch.ones(bs,).long().long().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90480e3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fac7fd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = model.add_noise(x,x,timesteps)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d92db7",
   "metadata": {},
   "outputs": [],
   "source": [
    "y = model.step(x,timesteps[0],x)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb057523",
   "metadata": {},
   "source": [
    "You can check that all the parameters of the UNet `network` are indeed parameters of the DDPM `model` like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28374b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "for n, p in model.named_parameters():\n",
    "    print(n, p.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef4f6222",
   "metadata": {},
   "outputs": [],
   "source": [
    "def training_loop(model, dataloader, optimizer, num_epochs, num_timesteps, device=device):\n",
    "    \"\"\"Training loop for DDPM\"\"\"\n",
    "\n",
    "    global_step = 0\n",
    "    losses = []\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        progress_bar = tqdm(total=len(dataloader))\n",
    "        progress_bar.set_description(f\"Epoch {epoch}\")\n",
    "        for step, batch in enumerate(dataloader):\n",
    "            batch = batch[0].to(device)\n",
    "            noise = torch.randn(batch.shape).to(device)\n",
    "            timesteps = torch.randint(0, num_timesteps, (batch.shape[0],)).long().to(device)\n",
    "\n",
    "            noisy = model.add_noise(batch, noise, timesteps)\n",
    "            noise_pred = model.reverse(noisy, timesteps)\n",
    "            loss = F.mse_loss(noise_pred, noise)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            progress_bar.update(1)\n",
    "            logs = {\"loss\": loss.detach().item(), \"step\": global_step}\n",
    "            losses.append(loss.detach().item())\n",
    "            progress_bar.set_postfix(**logs)\n",
    "            global_step += 1\n",
    "        \n",
    "        progress_bar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd6ab06c",
   "metadata": {},
   "outputs": [],
   "source": [
    "root_dir = './data/'\n",
    "transform01 = torchvision.transforms.Compose([\n",
    "        torchvision.transforms.Resize(32),\n",
    "        torchvision.transforms.ToTensor(),\n",
    "        torchvision.transforms.Normalize((0.5), (0.5))\n",
    "    ])\n",
    "dataset = torchvision.datasets.MNIST(root=root_dir, train=True, transform=transform01, download=True)\n",
    "dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=4096, shuffle=True, num_workers=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf8479b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "for b in dataloader:\n",
    "    batch = b[0]\n",
    "    break\n",
    "\n",
    "bn = [b for b in batch[:100]] \n",
    "show_images(bn, \"origin\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed5a9196",
   "metadata": {},
   "outputs": [],
   "source": [
    "learning_rate = 1e-3\n",
    "num_epochs = 50\n",
    "num_timesteps = 1000\n",
    "network = MyTinyUNet()\n",
    "network = network.to(device)\n",
    "model = DDPM(network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)\n",
    "optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)\n",
    "training_loop(model, dataloader, optimizer, num_epochs, num_timesteps, device=device)        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a2cd2e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_image(ddpm, sample_size, channel, size):\n",
    "    \"\"\"Generate the image from the Gaussian noise\"\"\"\n",
    "\n",
    "    frames = []\n",
    "    frames_mid = []\n",
    "    ddpm.eval()\n",
    "    with torch.no_grad():\n",
    "        timesteps = list(range(ddpm.num_timesteps))[::-1]\n",
    "        sample = torch.randn(sample_size, channel, size, size).to(device)\n",
    "        \n",
    "        for i, t in enumerate(tqdm(timesteps)):\n",
    "            time_tensor = (torch.ones(sample_size,1) * t).long().to(device)\n",
    "            residual = ddpm.reverse(sample, time_tensor)\n",
    "            sample = ddpm.step(residual, time_tensor[0], sample)\n",
    "\n",
    "            if t==500:\n",
    "                for i in range(sample_size):\n",
    "                    frames_mid.append(sample[i].detach().cpu())\n",
    "\n",
    "        for i in range(sample_size):\n",
    "            frames.append(sample[i].detach().cpu())\n",
    "    return frames, frames_mid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2ac7e7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "generated, generated_mid = generate_image(model, 100, 1, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d231ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_images(generated_mid, \"Mid result\")\n",
    "show_images(generated, \"Final result\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dd449fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rescale(x):\n",
    "    return (x+1)/2\n",
    "\n",
    "def show_images_rescale(images, title=\"\"):\n",
    "    \"\"\"Shows the provided images as sub-pictures in a square\"\"\"\n",
    "    images = [rescale((im.permute(1,2,0)).numpy()) for im in images]\n",
    "\n",
    "    # Defining number of rows and columns\n",
    "    fig = plt.figure(figsize=(8, 8))\n",
    "    rows = int(len(images) ** (1 / 2))\n",
    "    cols = round(len(images) / rows)\n",
    "\n",
    "    # Populating figure with sub-plots\n",
    "    idx = 0\n",
    "    for r in range(rows):\n",
    "        for c in range(cols):\n",
    "            fig.add_subplot(rows, cols, idx + 1)\n",
    "\n",
    "            if idx < len(images):\n",
    "                #plt.imshow(images[idx].reshape(pixel, pixel, n_channels), cmap=\"gray\")\n",
    "                plt.imshow(images[idx])\n",
    "                plt.axis('off')\n",
    "                idx += 1\n",
    "    fig.suptitle(title, fontsize=30)\n",
    "    \n",
    "    # Showing the figure\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72a38a96",
   "metadata": {},
   "outputs": [],
   "source": [
    "show_images_rescale(generated, \"Final result\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b0c353",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
